[libc++] Optimizations for uniform_int_distribution #140161
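@llvm/pr-subscribers-libcxx
Author: None (LRFLEW)

Changes

I noticed that the implementation for `std::uniform_int_distribution` in libc++ often has worse performance than other implementations. While it might be useful to change the implementation to one that rejects fewer samples, such a change would result in the output of the distribution changing, which could break programs dependent on the behavior. Instead, this PR attempts to optimize the existing implementation without changing the output of the distribution.

In libc++, `std::uniform_int_distribution` uses `__independent_bits_engine` to produce n-bit random numbers. The implementation of `__independent_bits_engine` is based on the specification for `std::independent_bits_engine`, except that the value for n is specified as a runtime value. However, the design of `std::independent_bits_engine` was optimized for the case where n is known at compile time. The algorithm used requires calculating a number of values, which `std::independent_bits_engine` can calculate at compile time, avoiding any runtime performance penalty. However, the use of this algorithm for `std::uniform_int_distribution` results in these values having to be calculated for every output generated, which ends up accounting for a significant portion of the algorithm's runtime.

While analyzing the definitions of the values that need to be generated and the algorithm in general, I identified two common cases where the algorithm could be simplified:

Additionally, I noticed that pretty much every use of a bit shift in the original implementation was guarded by an if clause that ensured the shift value did not cause UB. However, these checks introduce complexity and branching that can impact performance. I was able to remove these if statements by carefully considering what all the possible shift values can be. Note, I'm working with the assumption that 0 < w <= numeric_limits<result_type>::digits, which appears to be a safe assumption to make about w given how `__independent_bits_engine` is used by `std::uniform_int_distribution`.

All together, these changes provide minor performance improvements for sampling large ranges and significant performance improvements for sampling small ranges. The only other thing to note with this is that it removes the specific (and optimized) case for when Rp = 0 (representing R = 2^WDt) that was present previously. However, the optimizations for power-of-two values for R apply to this case as well, and there doesn't appear to be much benefit to trying to further optimize this case. This change also reuses the overload for `__eval` previously used for the Rp = 0 special case for handling power-of-two values of R.

Patch is 29.90 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/140161.diff

3 Files Affected: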
diff --git a/libcxx/include/__random/uniform_int_distribution.h b/libcxx/include/__random/uniform_int_distribution.h
index fa2c33755b739..31a2364dc56f9 100644
--- a/libcxx/include/__random/uniform_int_distribution.h
+++ b/libcxx/include/__random/uniform_int_distribution.h
@@ -64,7 +64,7 @@ class __independent_bits_engine {
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
// generating functions
- _LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, _Rp != 0>()); }
+ _LIBCPP_HIDE_FROM_ABI result_type operator()() { return __eval(integral_constant<bool, (_Rp & (_Rp - 1)) != 0>()); }
private:
_LIBCPP_HIDE_FROM_ABI result_type __eval(false_type);
@@ -74,49 +74,63 @@ class __independent_bits_engine {
template <class _Engine, class _UIntType>
__independent_bits_engine<_Engine, _UIntType>::__independent_bits_engine(_Engine& __e, size_t __w)
: __e_(__e), __w_(__w) {
- __n_ = __w_ / __m + (__w_ % __m != 0);
- __w0_ = __w_ / __n_;
- if (_Rp == 0)
- __y0_ = _Rp;
- else if (__w0_ < _WDt)
- __y0_ = (_Rp >> __w0_) << __w0_;
- else
- __y0_ = 0;
- if (_Rp - __y0_ > __y0_ / __n_) {
- ++__n_;
+ if (__w_ <= __m) {
+ __n_ = __n0_ = 1;
+ __w0_ = __w_;
+ __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
+ __y0_ = __y1_ = _Rp & ~__mask0_;
+ } else {
+ __n_ = (__w_ + __m - 1) / __m;
__w0_ = __w_ / __n_;
- if (__w0_ < _WDt)
- __y0_ = (_Rp >> __w0_) << __w0_;
- else
- __y0_ = 0;
+ __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
+ __y0_ = __y1_ = _Rp & ~__mask0_;
+ if _LIBCPP_CONSTEXPR_SINCE_CXX17 ((_Rp & (_Rp - 1)) != 0) {
+ if (_Rp - __y0_ > __y0_ / __n_) {
+ ++__n_;
+ __w0_ = __w_ / __n_;
+ __mask0_ = __mask1_ = ~_Engine_result_type(0) >> (_EDt - __w0_);
+ __y0_ = __y1_ = _Rp & ~__mask0_;
+ }
+ }
+ size_t __n1_ = __w_ % __n_;
+ __n0_ = __n_ - __n1_;
+ if (__n1_ > 0) {
+ __mask1_ = ~_Engine_result_type(0) >> (_EDt - (__w0_ + 1));
+ __y1_ = _Rp & ~__mask1_;
+ }
}
- __n0_ = __n_ - __w_ % __n_;
- if (__w0_ < _WDt - 1)
- __y1_ = (_Rp >> (__w0_ + 1)) << (__w0_ + 1);
- else
- __y1_ = 0;
- __mask0_ = __w0_ > 0 ? _Engine_result_type(~0) >> (_EDt - __w0_) : _Engine_result_type(0);
- __mask1_ = __w0_ < _EDt - 1 ? _Engine_result_type(~0) >> (_EDt - (__w0_ + 1)) : _Engine_result_type(~0);
}
template <class _Engine, class _UIntType>
inline _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(false_type) {
- return static_cast<result_type>(__e_() & __mask0_);
+ result_type __sp = (__e_() - _Engine::min()) & __mask0_;
+ for (size_t __k = 1; __k < __n0_; ++__k) {
+ __sp <<= __w0_;
+ __sp += (__e_() - _Engine::min()) & __mask0_;
+ }
+ for (size_t __k = __n0_; __k < __n_; ++__k) {
+ __sp <<= __w0_ + 1;
+ __sp += (__e_() - _Engine::min()) & __mask1_;
+ }
+ return __sp;
}
template <class _Engine, class _UIntType>
_UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
- const size_t __w_rt = numeric_limits<result_type>::digits;
- result_type __sp = 0;
- for (size_t __k = 0; __k < __n0_; ++__k) {
+ result_type __sp;
+ {
+ _Engine_result_type __u;
+ do {
+ __u = __e_() - _Engine::min();
+ } while (__u >= __y0_);
+ __sp = __u & __mask0_;
+ }
+ for (size_t __k = 1; __k < __n0_; ++__k) {
_Engine_result_type __u;
do {
__u = __e_() - _Engine::min();
} while (__u >= __y0_);
- if (__w0_ < __w_rt)
- __sp <<= __w0_;
- else
- __sp = 0;
+ __sp <<= __w0_;
__sp += __u & __mask0_;
}
for (size_t __k = __n0_; __k < __n_; ++__k) {
@@ -124,10 +138,7 @@ _UIntType __independent_bits_engine<_Engine, _UIntType>::__eval(true_type) {
do {
__u = __e_() - _Engine::min();
} while (__u >= __y1_);
- if (__w0_ < __w_rt - 1)
- __sp <<= __w0_ + 1;
- else
- __sp = 0;
+ __sp <<= __w0_ + 1;
__sp += __u & __mask1_;
}
return __sp;
@@ -218,9 +229,9 @@ typename uniform_int_distribution<_IntType>::result_type uniform_int_distributio
typedef __independent_bits_engine<_URNG, _UIntType> _Eng;
if (__rp == 0)
return static_cast<result_type>(_Eng(__g, __dt)());
- size_t __w = __dt - std::__countl_zero(__rp) - 1;
- if ((__rp & (numeric_limits<_UIntType>::max() >> (__dt - __w))) != 0)
- ++__w;
+ size_t __w = __dt - std::__countl_zero(__rp);
+ if ((__rp & (__rp - 1)) == 0)
+ return static_cast<result_type>(_Eng(__g, __w - 1)() + __p.a());
_Eng __e(__g, __w);
_UIntType __u;
do {
diff --git a/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp b/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp
new file mode 100644
index 0000000000000..6ddb309c5192d
--- /dev/null
+++ b/libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp
@@ -0,0 +1,36 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: c++03
+
+#include <benchmark/benchmark.h>
+#include <cstdint>
+#include <random>
+
+template <typename Eng, std::uint64_t Max>
+static void bm_rand_uni_int(benchmark::State& state) {
+ Eng eng;
+ std::uniform_int_distribution<std::uint64_t> dist(1ull, Max);
+ for (auto _ : state) {
+ benchmark::DoNotOptimize(dist(eng));
+ }
+}
+
+// n = 1
+BENCHMARK(bm_rand_uni_int<std::minstd_rand0, 1ull << 20>);
+BENCHMARK(bm_rand_uni_int<std::ranlux24_base, 1ull << 20>);
+
+// n = 2, n0 = 2
+BENCHMARK(bm_rand_uni_int<std::minstd_rand0, 1ull << 40>);
+BENCHMARK(bm_rand_uni_int<std::ranlux24_base, 1ull << 40>);
+
+// n = 2, n0 = 1
+BENCHMARK(bm_rand_uni_int<std::minstd_rand0, 1ull << 41>);
+BENCHMARK(bm_rand_uni_int<std::ranlux24_base, 1ull << 41>);
+
+BENCHMARK_MAIN();
diff --git a/libcxx/test/std/numerics/rand/rand.dist/rand.dist.uni/rand.dist.uni.int/output.pass.cpp b/libcxx/test/std/numerics/rand/rand.dist/rand.dist.uni/rand.dist.uni.int/output.pass.cpp
new file mode 100644
index 0000000000000..61af217c430eb
--- /dev/null
+++ b/libcxx/test/std/numerics/rand/rand.dist/rand.dist.uni/rand.dist.uni.int/output.pass.cpp
@@ -0,0 +1,651 @@
+//===----------------------------------------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// UNSUPPORTED: c++03
+
+// <random>
+
+#include <array>
+#include <random>
+#include <cassert>
+
+#include "test_macros.h"
+
+constexpr std::array<std::uint32_t, 1024> minstd_low_results = {
+ 0, 0, 1, 1, 4, 0, 2, 0, 8, 2, 2, 10, 12, 8, 2, 8, 14, 14, 3, 0, 3, 19, 19,
+ 18, 4, 19, 26, 23, 22, 3, 12, 2, 32, 21, 16, 15, 6, 34, 15, 28, 27, 8, 27, 20, 2, 31,
+ 6, 35, 26, 2, 12, 29, 16, 38, 19, 28, 17, 22, 0, 27, 43, 35, 61, 14, 34, 15, 16, 59, 1,
+ 11, 8, 33, 18, 65, 47, 22, 21, 34, 34, 18, 14, 32, 52, 12, 34, 78, 79, 46, 84, 23, 25, 81,
+ 11, 46, 18, 91, 0, 61, 41, 75, 3, 42, 98, 18, 89, 81, 9, 21, 89, 33, 74, 42, 3, 22, 62,
+ 68, 70, 98, 74, 97, 114, 2, 45, 110, 32, 31, 22, 122, 34, 99, 78, 64, 50, 22, 26, 93, 67, 116,
+ 74, 133, 109, 104, 69, 54, 33, 68, 32, 28, 19, 115, 104, 9, 29, 13, 62, 1, 139, 84, 8, 15, 124,
+ 133, 73, 98, 62, 44, 87, 117, 26, 20, 31, 164, 14, 20, 125, 102, 73, 35, 132, 95, 92, 27, 138, 0,
+ 116, 25, 136, 26, 158, 170, 51, 163, 156, 116, 183, 78, 125, 161, 89, 30, 162, 65, 150, 121, 51, 40, 46,
+ 199, 107, 155, 1, 138, 44, 144, 206, 31, 51, 132, 164, 75, 66, 40, 6, 144, 148, 120, 116, 71, 120, 14,
+ 77, 130, 69, 207, 119, 52, 110, 207, 154, 79, 140, 25, 7, 139, 77, 42, 163, 107, 181, 120, 79, 57, 90,
+ 184, 248, 168, 107, 18, 72, 77, 214, 135, 194, 40, 123, 212, 25, 61, 58, 40, 39, 6, 238, 24, 91, 166,
+ 151, 77, 231, 266, 180, 274, 14, 27, 113, 215, 185, 35, 60, 31, 92, 206, 185, 213, 8, 167, 22, 163, 92,
+ 63, 187, 245, 156, 92, 280, 65, 110, 56, 58, 208, 287, 181, 180, 25, 145, 19, 190, 251, 242, 127, 7, 28,
+ 249, 274, 206, 283, 174, 77, 217, 280, 139, 257, 235, 120, 251, 78, 299, 250, 96, 184, 19, 259, 161, 57, 91,
+ 14, 247, 283, 202, 233, 259, 333, 103, 316, 260, 182, 3, 253, 38, 98, 146, 39, 56, 279, 197, 311, 278, 330,
+ 318, 276, 65, 67, 337, 370, 278, 14, 127, 353, 250, 360, 73, 144, 299, 260, 60, 219, 281, 49, 185, 171, 191,
+ 269, 182, 250, 222, 219, 14, 160, 321, 91, 11, 20, 100, 363, 320, 180, 228, 91, 221, 3, 184, 405, 200, 241,
+ 127, 397, 323, 220, 171, 67, 111, 14, 384, 388, 106, 389, 195, 364, 149, 310, 196, 252, 63, 228, 223, 264, 429,
+ 90, 125, 46, 213, 310, 269, 69, 55, 174, 382, 400, 101, 386, 354, 111, 98, 318, 401, 55, 148, 315, 294, 281,
+ 114, 432, 188, 221, 146, 165, 17, 148, 109, 377, 453, 211, 128, 136, 362, 246, 264, 81, 358, 219, 453, 297, 2,
+ 300, 345, 249, 307, 26, 92, 275, 455, 215, 277, 370, 58, 343, 408, 441, 199, 330, 480, 341, 57, 185, 211, 291,
+ 283, 198, 378, 506, 2, 478, 400, 261, 405, 138, 409, 145, 19, 200, 433, 327, 63, 259, 218, 304, 215, 350, 395,
+ 415, 404, 3, 442, 340, 89, 284, 82, 515, 63, 39, 374, 78, 410, 354, 415, 388, 229, 430, 234, 188, 483, 136,
+ 548, 306, 400, 9, 27, 510, 326, 93, 317, 12, 332, 232, 214, 58, 35, 310, 1, 515, 175, 21, 387, 67, 400,
+ 474, 191, 47, 573, 297, 44, 260, 440, 33, 28, 431, 437, 207, 506, 235, 498, 386, 526, 230, 349, 118, 185, 510,
+ 282, 520, 171, 284, 360, 47, 81, 83, 424, 188, 101, 0, 60, 433, 428, 277, 105, 190, 70, 574, 478, 519, 409,
+ 400, 133, 428, 548, 308, 302, 372, 243, 220, 564, 201, 565, 77, 179, 391, 339, 309, 333, 426, 239, 277, 51, 412,
+ 632, 643, 468, 276, 112, 498, 80, 546, 107, 508, 632, 622, 436, 390, 555, 152, 366, 248, 652, 147, 329, 279, 496,
+ 107, 17, 464, 153, 654, 365, 281, 619, 163, 502, 333, 616, 96, 9, 382, 392, 225, 248, 566, 412, 50, 195, 219,
+ 291, 468, 690, 97, 120, 681, 521, 324, 367, 286, 449, 31, 6, 242, 669, 545, 546, 91, 297, 324, 122, 158, 223,
+ 336, 446, 93, 357, 669, 601, 521, 678, 473, 712, 116, 311, 531, 17, 435, 340, 516, 148, 34, 301, 449, 145, 575,
+ 55, 287, 594, 210, 262, 478, 219, 301, 574, 699, 199, 485, 564, 341, 217, 665, 435, 281, 665, 751, 627, 673, 434,
+ 64, 417, 35, 521, 438, 474, 503, 93, 37, 723, 726, 395, 415, 398, 178, 495, 96, 520, 714, 394, 248, 408, 61,
+ 307, 600, 10, 582, 544, 389, 548, 220, 99, 238, 599, 424, 263, 422, 270, 508, 678, 22, 121, 504, 368, 241, 662,
+ 118, 776, 552, 765, 565, 154, 532, 220, 537, 192, 321, 99, 74, 369, 358, 538, 170, 546, 724, 44, 138, 712, 195,
+ 426, 771, 352, 443, 475, 431, 130, 600, 46, 262, 278, 53, 317, 536, 7, 95, 77, 417, 524, 223, 220, 572, 730,
+ 243, 503, 348, 599, 448, 606, 132, 790, 814, 796, 714, 313, 253, 686, 206, 152, 862, 140, 10, 699, 278, 159, 31,
+ 327, 421, 611, 205, 675, 870, 633, 382, 311, 695, 879, 529, 405, 357, 850, 176, 884, 826, 356, 12, 530, 35, 217,
+ 660, 102, 668, 525, 456, 122, 227, 145, 553, 367, 359, 575, 331, 641, 211, 221, 886, 888, 730, 321, 15, 78, 303,
+ 482, 30, 397, 236, 266, 222, 171, 147, 430, 677, 246, 470, 371, 616, 89, 325, 236, 748, 622, 444, 669, 783, 218,
+ 789, 4, 655, 657, 128, 398, 490, 190, 459, 488, 309, 378, 206, 243, 21, 167, 690, 402, 308, 777, 1, 374, 810,
+ 19, 254, 406, 384, 602, 611, 703, 646, 211, 71, 424, 823, 910, 529, 820, 743, 249, 248, 297, 371, 982, 104, 941,
+ 223, 520, 991, 772, 762, 922, 107, 173, 333, 980, 447, 945, 358, 913, 1003, 929, 806, 965, 567, 566, 3, 817, 442,
+ 227, 650, 107, 613, 303, 431, 991, 752, 402, 310, 289, 493,
+};
+
+constexpr std::array<std::uint32_t, 1024> mt19937_low_results{
+ 0, 0, 2, 2, 1, 4, 5, 1, 3, 5, 4, 10, 0, 7, 4, 9, 1, 17, 10, 18, 13, 17, 14,
+ 7, 9, 13, 11, 8, 12, 26, 15, 0, 5, 12, 27, 6, 2, 13, 11, 18, 39, 25, 2, 2, 36, 36,
+ 9, 25, 9, 26, 33, 7, 51, 46, 1, 8, 20, 40, 38, 3, 43, 12, 18, 9, 14, 63, 8, 11, 51,
+ 24, 17, 42, 54, 39, 16, 69, 51, 33, 8, 67, 62, 69, 35, 35, 9, 39, 65, 86, 15, 42, 52, 60,
+ 46, 27, 74, 2, 4, 10, 36, 74, 3, 90, 21, 39, 57, 92, 97, 106, 10, 29, 5, 31, 39, 23, 79,
+ 95, 10, 22, 40, 51, 57, 66, 113, 120, 80, 64, 68, 1, 35, 83, 30, 53, 66, 65, 100, 89, 34, 14,
+ 99, 117, 135, 77, 72, 8, 52, 2, 97, 68, 124, 130, 72, 109, 93, 15, 50, 130, 70, 104, 73, 48, 32,
+ 157, 128, 136, 104, 63, 42, 126, 108, 8, 164, 87, 104, 36, 94, 135, 91, 129, 96, 43, 8, 29, 141, 108,
+ 78, 72, 20, 77, 48, 115, 14, 55, 68, 144, 148, 41, 148, 100, 4, 185, 2, 85, 182, 3, 91, 45, 161,
+ 195, 48, 188, 113, 20, 41, 4, 160, 204, 97, 185, 156, 64, 156, 127, 162, 120, 28, 33, 65, 171, 59, 210,
+ 31, 84, 227, 67, 116, 186, 39, 203, 106, 219, 38, 144, 49, 77, 147, 17, 145, 50, 235, 33, 117, 115, 77,
+ 137, 51, 99, 12, 68, 189, 60, 130, 77, 240, 87, 25, 203, 182, 34, 71, 126, 99, 182, 80, 87, 147, 15,
+ 16, 175, 249, 237, 219, 160, 272, 236, 152, 151, 144, 269, 11, 249, 78, 127, 68, 244, 47, 143, 83, 296, 157,
+ 299, 54, 288, 293, 169, 192, 194, 299, 253, 98, 187, 36, 226, 145, 279, 300, 245, 65, 151, 100, 186, 88, 187,
+ 206, 120, 117, 109, 117, 265, 264, 46, 227, 231, 55, 166, 27, 101, 248, 271, 39, 221, 16, 99, 204, 309, 162,
+ 56, 234, 171, 291, 15, 218, 177, 87, 97, 225, 54, 270, 232, 293, 278, 310, 321, 61, 144, 0, 19, 141, 185,
+ 106, 88, 357, 95, 104, 59, 104, 103, 10, 301, 243, 18, 377, 221, 317, 218, 49, 24, 110, 179, 369, 43, 363,
+ 179, 242, 264, 307, 152, 287, 382, 350, 132, 268, 123, 81, 347, 136, 3, 273, 149, 282, 11, 47, 358, 269, 19,
+ 253, 259, 283, 339, 88, 193, 104, 164, 249, 310, 40, 352, 334, 275, 243, 197, 39, 71, 52, 389, 95, 386, 18,
+ 340, 397, 318, 166, 138, 72, 110, 171, 48, 66, 247, 436, 206, 394, 431, 156, 433, 148, 345, 160, 298, 381, 347,
+ 95, 279, 455, 447, 353, 163, 334, 274, 185, 219, 170, 275, 101, 218, 106, 123, 8, 339, 257, 201, 399, 396, 209,
+ 290, 335, 300, 210, 291, 293, 269, 422, 212, 110, 177, 137, 292, 444, 385, 57, 258, 204, 277, 372, 56, 188, 221,
+ 352, 179, 318, 10, 457, 157, 38, 480, 499, 47, 406, 24, 247, 61, 341, 316, 482, 113, 241, 416, 460, 500, 338,
+ 368, 18, 39, 423, 461, 110, 326, 330, 4, 150, 404, 501, 452, 72, 90, 100, 110, 171, 299, 152, 36, 491, 292,
+ 193, 227, 488, 188, 299, 428, 98, 214, 168, 468, 175, 248, 73, 475, 518, 152, 288, 212, 63, 240, 354, 362, 313,
+ 282, 168, 488, 465, 410, 332, 386, 300, 378, 16, 73, 134, 133, 263, 521, 67, 500, 370, 487, 298, 201, 526, 583,
+ 9, 328, 566, 375, 306, 486, 563, 406, 43, 79, 534, 606, 84, 417, 68, 412, 531, 560, 222, 358, 372, 19, 345,
+ 166, 203, 59, 250, 614, 216, 141, 501, 304, 462, 19, 539, 46, 602, 80, 337, 325, 574, 596, 275, 315, 163, 66,
+ 166, 341, 193, 176, 535, 630, 579, 434, 607, 62, 65, 87, 487, 162, 627, 362, 312, 321, 177, 513, 525, 625, 512,
+ 530, 313, 641, 112, 369, 90, 503, 534, 328, 177, 636, 129, 372, 290, 368, 187, 3, 32, 526, 405, 536, 505, 446,
+ 135, 611, 315, 147, 267, 42, 280, 284, 120, 230, 527, 242, 564, 162, 510, 120, 289, 587, 267, 594, 67, 316, 437,
+ 469, 80, 259, 349, 612, 673, 507, 715, 289, 614, 148, 257, 659, 588, 189, 336, 386, 170, 74, 157, 645, 667, 93,
+ 301, 139, 193, 68, 624, 235, 620, 116, 112, 327, 706, 125, 141, 425, 2, 459, 378, 540, 538, 351, 19, 76, 322,
+ 511, 558, 471, 448, 717, 67, 685, 283, 89, 402, 233, 570, 230, 216, 159, 68, 663, 5, 297, 146, 720, 195, 482,
+ 15, 674, 256, 384, 494, 468, 435, 210, 191, 742, 356, 696, 418, 689, 49, 537, 406, 40, 437, 598, 290, 251, 389,
+ 478, 513, 378, 102, 162, 543, 211, 301, 78, 503, 15, 393, 142, 85, 41, 617, 126, 776, 639, 304, 24, 586, 573,
+ 67, 101, 203, 225, 7, 197, 498, 687, 812, 548, 96, 0, 518, 550, 93, 653, 103, 596, 667, 519, 448, 565, 298,
+ 13, 214, 143, 575, 244, 776, 797, 256, 609, 104, 211, 455, 303, 573, 659, 314, 397, 53, 133, 695, 189, 310, 349,
+ 770, 345, 645, 192, 700, 347, 603, 619, 302, 818, 64, 29, 656, 822, 59, 570, 530, 603, 183, 804, 360, 388, 297,
+ 62, 514, 765, 885, 625, 735, 883, 356, 562, 97, 335, 652, 398, 15, 527, 69, 181, 169, 165, 885, 191, 23, 282,
+ 743, 201, 783, 125, 427, 304, 742, 530, 61, 531, 289, 654, 739, 545, 511, 576, 587, 219, 877, 638, 812, 319, 535,
+ 71, 210, 469, 223, 259, 942, 183, 698, 243, 589, 660, 488, 707, 668, 710, 727, 184, 335, 931, 172, 412, 538, 514,
+ 708, 410, 610, 649, 200, 220, 783, 619, 293, 143, 146, 602, 798, 173, 516, 339, 213, 923, 976, 728, 706, 175, 534,
+ 115, 396, 511, 162, 832, 832, 537, 237, 335, 738, 663, 750, 968, 148, 166, 853, 673, 234, 580, 340, 467, 423, 851,
+ 805, 99, 233, 380, 119, 908, 956, 861, 510, 515, 670, 400,
+};
+
+constexpr std::array<std::uint64_t, 256> minstd_high_results = {
+ 0,
+ 1619687266273,
+ 547825395978620,
+ 823097292413806,
+ 591548725059916,
+ 235469987414414,
+ 601800132705667,
+ 1783241889905080,
+ 2051816675310482,
+ 1841237109167236,
+ 1209786246595226,
+ 2146293784604983,
+ 1229955559442691,
+ 1876783466283202,
+ 3748367557330986,
+ 169686692421583,
+ 1256383104102300,
+ 470231007626659,
+ 1652887366564738,
+ 4196040995203280,
+ 3176751938948606,
+ 5269566674535580,
+ 1094497494027350,
+ 1232879032352347,
+ 1199993928302435,
+ 4760369655010062,
+ 5823018175849615,
+ 969423329151101,
+ 6066196886144982,
+ 7750268748163783,
+ 4727729787832968,
+ 7265102324616433,
+ 1445295389123474,
+ 8804124791765807,
+ 6260374552080761,
+ 5903294081544854,
+ 9635607377896799,
+ 2177954562107170,
+ 7572616078392338,
+ 2091209451121912,
+ 5384481869817122,
+ 5297548383674487,
+ 282525371298395,
+ 1468926476589939,
+ 4593826625397632,
+ 8764944894070205,
+ 6582259066132137,
+ 10817572468301315,
+ 12791992773334186,
+ 2711786844508690,
+ 2897537567686361,
+ 13141477009672841,
+ 1202645279322585,
+ 328570709834390,
+ 3188867030102980,
+ 15424319378088546,
+ 14964658462672242,
+ 14788076765788142,
+ 8389690444079391,
+ 15062693868653306,
+ 10029509229282147,
+ 6170135033813326,
+ 5038120660983274,
+ 1480947715804978,
+ 9015417911070915,
+ 2331337575098984,
+ 2626536113675332,
+ 10420116761676403,
+ 16511463003848713,
+ 2468888236479408,
+ 9039961620284990,
+ 11280895620331521,
+ 15594410759685579,
+ 4323283785890172,
+ 5190333082152290,
+ 9392045150347732,
+ 4946974037897421,
+ 6491750398295843,
+ 10983166354059997,
+ 9818088029369736,
+ 7500717649486573,
+ 17871595299459242,
+ 11158368906025884,
+ 15186394147139234,
+ 178703193469736,
+ 16878600282746311,
+ 10970923764139982,
+ 7081562318005570,
+ 19618321890284550,
+ 22178866902166527,
+ 1755013744089463,
+ 19708148762225036,
+ 10769242287913355,
+ 4919671137968235,
+ 24396405042595295,
+ 5428529973083981,
+ 11413353950692987,
+ 5867051170659865,
+ 18900758747343400,
+ 10050345523069479,
+ 25182489846569198,
+ 26969078486129243,
+ 7055131302411430,
+ 23737397081026125,
+ 16662473560367370,
+ 27406775304400255,
+ 21935529703219726,
+ 22922626937690462,
+ 18151312663083214,
+ 17151606492716362,
+ 13868059936509348,
+ 23708084429433212,
+ 2...
[truncated]
I included two test files I made while validating this change as part of this PR. The first is a test file that tests the output of The second is a benchmark for I ran this test on my machine (i7 8700K, WSL on Windows 10) with and without the changes made in this PR, took the median time of 16 repetitions, and compiled the result in the table below. This should give a good idea of the impact this change will have.
I kept these test file additions in their own commit preceding the changes to the libc++ header. I did this because a) it makes it easier to perform the before-and-after tests, and b) I'm not sure if these test files should be merged. While the benchmark may be useful, the other test file likely isn't useful, since it doesn't test anything related to standard conformance or a known regression. I'll leave the tests as they are now for code review. However, let me know if you want me to squash the commits or remove any of the test files, and I'll do it.
✅ With the latest revision this PR passed the C/C++ code formatter.

Force-pushed from a16c72f to d2b4d65.
You said that you avoided changing the algorithm. Do you think we should use a different algorithm that has improved performance? Given that this is a random distribution I don't think it's a huge problem to change what exactly it produces.
Short Answer: Maybe, but what to replace it with is a non-trivial problem. Also, this PR addresses a lot of the concern I have with the current method.

For context, I'm going to reference the algorithms as described in this blog post on the PCG website, because it's one of the best write-ups of the various resampling methods I've found. Note, however, that the code snippets (and to a lesser extent, the algorithms themselves) use a much more narrow contract than

The algorithm used here is basically the "Bitmask" method described there. This method has a critical trade-off, where it avoids having to do a more expensive division (or multiplication), but it generally performs significantly more sample rejections. Whether this performs better or worse depends on multiple factors, such as how fast the PRNG algorithm is, how fast division is on the device, whether the CPU has a CLZ instruction (or if it has to be emulated in software), and what kinds of ranges are actually sampled. There is one more thing to keep in mind with this though: the "Bitmask" method assumes the PRNG is a power-of-two PRNG. Since

There are a lot of alternative options if we did want to change the algorithm. From what I've seen, libstdc++ uses a combination of the "Debiased Integer Multiplication - Lemire's Method" and "Division with Rejection (Unbiased)" methods depending on the PRNG and available instruction set. There's also the different modulus methods, which might be worth considering. It's also worth keeping in mind that the current method provides the widening method (where the output range is larger than the input range) basically for free, while switching methods will require deciding on a method for that as well.

If there's interest in possibly changing the algorithm used, I'd be open to having and participating in a larger discussion around it. It should probably include running performance tests on a wider range of hardware than I personally have access to, as we probably shouldn't just make assumptions about all hardware based on modern high-end hardware. In the meantime, I think it's worth considering this PR on its own, as it addresses a lot of the concern I have with the current method.
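To make the trade-off of the "Bitmask" method concrete, here is a minimal sketch of it for an engine whose output is uniform over all 32-bit values. This is not the libc++ implementation: the function name `bitmask_sample` is hypothetical, and the sketch assumes `std::mt19937` as the engine and a target range of [0, max_value].

```cpp
#include <cassert>
#include <cstdint>
#include <random>

// Hypothetical helper (not part of libc++): draw a value in [0, max_value]
// from an engine whose output is uniform over all 32-bit values.
inline std::uint32_t bitmask_sample(std::mt19937& eng, std::uint32_t max_value) {
  // Smear the highest set bit downward to build the smallest all-ones mask
  // that covers max_value.
  std::uint32_t mask = max_value;
  mask |= mask >> 1;
  mask |= mask >> 2;
  mask |= mask >> 4;
  mask |= mask >> 8;
  mask |= mask >> 16;

  std::uint32_t x;
  do {
    // Masking is cheap: no division or multiplication is needed.
    x = static_cast<std::uint32_t>(eng()) & mask;
  } while (x > max_value); // reject draws that fall outside the range
  return x;
}
```

Each draw is accepted with probability (max_value + 1) / (mask + 1), which is always above one half but can be close to it (for example when max_value is itself a power of two). That is the extra rejection cost weighed against the avoided division.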
Oh, and on the note of potentially making breaking changes, there is another possible optimization that might be worth considering, but instead of breaking expected behavior, it would break ABI. The biggest performance cost of this implementation (particularly for larger ranges) is determining the sampling parameters in the constructor for There would basically be two options for doing this. Firstly, we could add an instance of

I didn't include this type of change in this PR because I was worried about complications from these kinds of breaking changes (with ABI changes being particularly complicated). If we're willing to make breaking changes, this is a change worth considering, but for now I'm gonna keep this PR as non-breaking, as it seems like the safer option.
First of all, thanks for the amazingly detailed writeup. This is very useful to vet this change since most of us are not super familiar with the details of our `<random>` implementation, which was written early in the library's lifetime and not touched much afterwards.
I will echo @philnik777 's comment that changing the exact values returned by this type is not a problem, as long as we retain a similarly good (or better) uniform distribution. Changing the ABI of the type is impossible to do by default, however we can definitely do it behind an ABI macro if there is a sufficiently large benefit. That way, adopters of unstable ABI features can opt into the change, and ABI-stable vendors creating new ABIs (typically when introducing a new architecture) can also benefit from it.
About this change itself, I pulled your patch and ran the benchmark on my machine (arm64 mac studio). The results are:
$ libcxx/utils/libcxx-compare-benchmarks build/default build/candidate libcxx/test/benchmarks/numeric/rand.uni.int.bench.cpp
Comparing build/default/libcxx/test/benchmarks/numeric/Output/rand.uni.int.bench.cpp.dir/benchmark-result.json to build/candidate/libcxx/test/benchmarks/numeric/Output/rand.uni.int.bench.cpp.dir/benchmark-result.json
| Benchmark | Time | CPU | Time Old | Time New | CPU Old | CPU New |
|---|---|---|---|---|---|---|
| `bm_rand_uni_int<std::minstd_rand0, 1ull << 20>` | -0.0434 | -0.0434 | 7 | 7 | 7 | 7 |
| `bm_rand_uni_int<std::ranlux24_base, 1ull << 20>` | -0.1183 | -0.1183 | 10 | 9 | 10 | 9 |
| `bm_rand_uni_int<std::minstd_rand0, 1ull << 40>` | -0.0177 | -0.0176 | 13 | 12 | 13 | 12 |
| `bm_rand_uni_int<std::ranlux24_base, 1ull << 40>` | -0.1081 | -0.1081 | 16 | 14 | 16 | 14 |
| `bm_rand_uni_int<std::minstd_rand0, 1ull << 41>` | -0.1236 | -0.1246 | 14 | 13 | 14 | 13 |
| `bm_rand_uni_int<std::ranlux24_base, 1ull << 41>` | -0.1553 | -0.1561 | 17 | 15 | 17 | 15 |
| OVERALL_GEOMEAN | -0.0957 | -0.0959 | 0 | 0 | 0 | 0 |
So overall, I'm seeing roughly a 10% improvement, which is great. Do you have a hypothesis to explain why you were seeing up to 2.83x on some benchmarks and I am seeing more modest numbers?
I think the next step for landing this patch would be to address my minor review comments, rebase this (there are no conflicts) and this is probably fine to land. Any further improvement in the algorithm or potential ABI breaking change should be discussed in a follow-up. Let's not make perfection the enemy of good!
@@ -64,7 +64,7 @@ class __independent_bits_engine {
_LIBCPP_HIDE_FROM_ABI __independent_bits_engine(_Engine& __e, size_t __w);
Unrelated to this file: let's add a release note for this optimization. You can add it in `libcxx/docs/ReleaseNotes/21.rst`.
Do you have a suggestion on how to word this exactly? Should I just put something like this under "Improvements and New Features"?
The implementation of `std::uniform_int_distribution` has been optimized for common cases.
Force-pushed from 4e418c8 to 76d24f5.
Ok, I've gone and addressed (most of) the review comments and rebased against the main branch. I removed the output test case, since that was only intended to be temporary anyway, but kept the benchmark in its own commit for a bit. The testing results you got have me wanting to try running the benchmark on some other devices to get a more complete picture of what the results of this change are.
The large improvements I saw were for the case where n = 1, with the main optimization there being the removal of two division calculations (to calculate the values for |
Force-pushed from 5f25497 to 5f41307.
I went ahead and updated the benchmark program. I had written the benchmark with this change in mind, which primarily focused on changes to The test cases work out like this, where I was also able to get
After seeing the benchmark results from the Apple Silicon mac, I wanted to run the benchmark on a different ARM64 computer. My first choice would have been a Raspberry Pi (or similar), but I don't have one of those on hand. However, I did end up getting Linux running on my Nintendo Switch, which is also ARM64. These are the results I got from that:
It's not as dramatic a difference as I saw on my 8700K, but it is certainly more significant than the results on Apple Silicon. With this, I'm confident in my assessment that the difference in results is due to the performance difference of integer division between the processors.
I noticed that the implementation for `std::uniform_int_distribution` in libc++ often has worse performance than other implementations. While it might be useful to change the implementation to one that rejects fewer samples, such a change would result in the output of the distribution changing, which could break programs dependent on the behavior. Instead, this PR attempts to optimize the existing implementation without changing the output of the distribution.

In libc++, `std::uniform_int_distribution` uses `__independent_bits_engine` to produce n-bit random numbers. The implementation of `__independent_bits_engine` is based on the specification for `std::independent_bits_engine`, except that the value for n is specified as a runtime value. However, the design of `std::independent_bits_engine` was optimized for the case where n is known at compile time. The algorithm used requires calculating a number of values, which `std::independent_bits_engine` can calculate at compile time, avoiding any runtime performance penalty. However, the use of this algorithm for `std::uniform_int_distribution` results in these values having to be calculated for every output generated, which ends up accounting for a significant portion of the algorithm's runtime.

While analyzing the definitions of the values that need to be generated and the algorithm in general, I identified two common cases where the algorithm could be simplified:

`__eval` that limits branching and doesn't read the values for y0 or y1 (allowing the optimizer to potentially eliminate the variables entirely). Additionally, the check for the condition R - y0 <= floor(y0 / n) simplifies to 0 <= floor(R / n), which is always true, meaning we can skip this check entirely when generating the required values.

Additionally, I noticed that pretty much every use of a bit shift in the original implementation was guarded by an if clause that ensured the shift value did not cause UB. However, these checks introduce complexity and branching that can impact performance. I was able to remove these if statements by carefully considering what all the possible shift values can be. Note, I'm working with the assumption that 0 < w <= numeric_limits<result_type>::digits, which appears to be a safe assumption to make about w given how `__independent_bits_engine` is used by `std::uniform_int_distribution`.

`__eval`, the case where n > 1 will never result in a shift value that results in UB. This is because the relation n0*w0 + (n - n0)*(w0 + 1) = w gives that w0 < w, and if n > n0, w0 + 1 < w. This means the only case that needs to be handled is when n = 1. This is done by manually unrolling the k = 0 iteration and removing the shift step from it. This makes the assumptions that n > 0 and n0 > 0, but this is always true given w > 0 and n - n0 < n respectively (the latter being given by the "mod n" in the definition of n0).

All together, these changes provide minor performance improvements for sampling large ranges and significant performance improvements for sampling small ranges. The only other thing to note with this is that it removes the specific (and optimized) case for when Rp = 0 (representing R = 2^WDt) that was present previously. However, the optimizations for power-of-two values for R apply to this case as well, and there doesn't appear to be much benefit to trying to further optimize this case. This change also reuses the overload for `__eval` previously used for the Rp = 0 special case for handling power-of-two values of R.
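The shift-safety argument above can be checked with a small standalone sketch. It is an illustration only, not the libc++ code: `BitSplit`, `split_bits`, `is_pow2_or_zero`, and `low_mask` are hypothetical names, and the sketch assumes an engine that yields exactly m uniform bits per call (R a power of two), so the extra "R - y0 > y0 / n" adjustment never applies.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <limits>

// Hypothetical illustration type; field names mirror the PR's n, n0, w0.
struct BitSplit {
  std::size_t n;  // number of engine calls
  std::size_t n0; // calls contributing w0 bits (the rest contribute w0 + 1)
  std::size_t w0; // smaller per-call bit count
};

// Split w output bits across calls to an engine that yields m uniform bits
// per call. Assumes 0 < w and 0 < m.
inline BitSplit split_bits(std::size_t w, std::size_t m) {
  BitSplit s;
  s.n  = (w + m - 1) / m; // ceil(w / m)
  s.w0 = w / s.n;
  s.n0 = s.n - w % s.n;
  return s;
}

// True when r is a power of two, or zero (zero stands for 2^64 after wraparound).
inline bool is_pow2_or_zero(std::uint64_t r) { return (r & (r - 1)) == 0; }

// Mask with the low `bits` bits set. Only defined for 0 < bits <= 64, which is
// why no branch is needed as long as the caller guarantees that range.
inline std::uint64_t low_mask(std::size_t bits) {
  return ~std::uint64_t(0) >> (std::numeric_limits<std::uint64_t>::digits - bits);
}
```

For example, `split_bits(41, 24)` gives n = 2, n0 = 1, w0 = 20, which appears consistent with the "n = 2, n0 = 1" comment on the `std::ranlux24_base` benchmark case (assuming that engine supplies 24 bits per call). Looping over all 0 < w, m <= 64 and asserting n0*w0 + (n - n0)*(w0 + 1) == w, plus w0 < w whenever n > 1, is a quick way to convince yourself that only the n = 1 case can produce a full-width shift.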